
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import seaborn as sns
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
    
    
def data_high_generator(samplesize, env, *, verbose=False):
    """Sample a 2-D, two-class toy dataset whose second feature depends on ``env``.

    Each sample first gets a latent digit in ``0..9``.

    - ``x``: drawn from N(x_mean[digit], 20**2). The means are spaced 40
      apart, from -180 to 180.
    - ``z``: drawn from N(coef[digit] * env, 30**2).
    - Label: the latent digit is collapsed to the binary label ``digit > 4``.

    Args:
        samplesize: Number of samples to draw.
        env: Environment scale; multiplies the per-digit mean of ``z``.
        verbose: If True, print the latent digit array. The original code
            always printed it as a debugging leftover; it is now opt-in.

    Returns:
        dict with two entries:

        - ``'images'``: float32 tensor of shape (samplesize, 2) holding
          ``[x, z] / 150``.
        - ``'labels'``: float tensor of shape (samplesize, 1) with values
          0.0 / 1.0.
    """
    # (digit, x mean, z-mean coefficient). The tuple order fixes the sequence
    # of RNG draws; keep it so seeded runs reproduce the original stream.
    components = (
        (0, -180, -5), (5, -140, 4), (1, -100, -3), (6, -60, 2), (2, -20, -1),
        (7, 20, 1), (3, 60, -2), (8, 100, 3), (4, 140, -4), (9, 180, 5),
    )
    digits = np.random.randint(0, 10, samplesize)
    if verbose:
        print(digits)
    masks = [digits == digit for digit, _, _ in components]

    # The N(0, 1) placeholders are always fully overwritten, because every
    # possible digit appears in ``components``. They are still drawn so the
    # RNG stream stays identical to the original implementation.
    x = np.random.normal(loc=0, scale=1.0, size=samplesize)
    for mask, (_, x_mean, _) in zip(masks, components):
        x[mask] = np.random.normal(loc=x_mean, scale=20.0, size=x[mask].shape)

    z = np.random.normal(loc=0, scale=1.0, size=samplesize)
    for mask, (_, _, coef) in zip(masks, components):
        z[mask] = np.random.normal(loc=coef * env, scale=30.0, size=z[mask].shape)

    # Collapse the 10 latent digits into two classes: 0-4 -> 0.0, 5-9 -> 1.0.
    labels = (torch.from_numpy(digits[:, None]).float() > 4).float()

    features = np.stack([x, z], axis=1)
    return {
        'images': torch.from_numpy(features.astype(np.float32)) / 150,
        'labels': labels,
    }


def data_generator(samplesize, env):
    """Sample a 2-D toy dataset labelled with ten latent classes.

    The first feature of class ``d`` is Gaussian with std 20, centred on one
    of -180, -140, ..., 180. The second feature is Gaussian with std 30,
    centred on ``slope[d] * env``, so ``env`` moves the second feature
    differently for each class.

    Author's note (unverified here): the best possible accuracy on this data
    is about 84%.

    Args:
        samplesize: Number of samples to draw.
        env: Environment scale for the second feature's class means.

    Returns:
        dict with two entries:

        - ``'images'``: float32 tensor of shape (samplesize, 2), scaled by
          1/150.
        - ``'labels'``: float tensor of shape (samplesize, 1) holding the
          class ids 0..9.
    """
    # Classes are filled in this exact order; it determines the RNG draw order.
    draw_order = (0, 5, 1, 6, 2, 7, 3, 8, 4, 9)
    # Centres step by 40 along draw_order: 0 -> -180, 5 -> -140, ..., 9 -> 180.
    x_centre = {digit: -180 + 40 * rank for rank, digit in enumerate(draw_order)}
    z_slope = {0: -5, 1: -3, 2: -1, 3: -2, 4: -4, 5: 4, 6: 2, 7: 1, 8: 3, 9: 5}

    classes = np.random.randint(0, 10, samplesize)
    picks = {digit: classes == digit for digit in draw_order}

    first = np.random.normal(loc=0, scale=1.0, size=samplesize)
    for digit in draw_order:
        sel = picks[digit]
        first[sel] = np.random.normal(loc=x_centre[digit], scale=20.0, size=first[sel].shape)

    second = np.random.normal(loc=0, scale=1.0, size=samplesize)
    for digit in draw_order:
        sel = picks[digit]
        second[sel] = np.random.normal(loc=z_slope[digit] * env, scale=30.0, size=first[sel].shape)

    label_tensor = torch.from_numpy(classes[:, None]).float()
    stacked = np.column_stack((first, second))
    image_tensor = torch.from_numpy(stacked.astype(np.float32)) / 150
    return {'images': image_tensor, 'labels': label_tensor}

def _scatter_by_label(axis, data, title):
    """Scatter ``data['images'] * 150`` on ``axis``, one distinct colour per class 0-9."""
    # Labels 0-5 keep their original colours; 6-9 get new, distinct colours.
    # Previously 6-9 reused 2-5's colours, with a different mapping per panel.
    colors = ("b", "r", "g", "c", "m", "y", "k", "tab:orange", "tab:purple", "tab:brown")
    axis.title.set_text(title)
    axis.set_ylim(-350, 350)
    # NOTE(review): data_generator centres x-clusters up to +/-180, so these
    # limits clip the outermost classes -- confirm this is intended.
    axis.set_xlim(-150, 150)
    points = data['images'] * 150  # undo the 1/150 scaling applied by the generator
    flat_labels = data['labels'].reshape(-1)
    for label, color in enumerate(colors):
        chosen = points[flat_labels == label]
        axis.scatter(chosen[:, 0], chosen[:, 1], color=color, label="label:{}".format(label), s=10)


def data_loader(envs, sample_size):
    """Generate, plot and return a train/test pair from ``data_generator``.

    The training set is drawn with environment ``envs`` and the test set
    with ``-envs``, which flips the sign of the second feature's class means.
    Both sets are scattered side by side, coloured by their 10-class label,
    and the figure is displayed with ``plt.show()``.

    Args:
        envs: Environment value for the training set; its negation is used
            for the test set.
        sample_size: Number of samples in each of the two sets.

    Returns:
        list ``[data_train, data_test]`` of the dicts produced by
        ``data_generator``.
    """
    fig = plt.figure(figsize=(15, 13))
    # The title used to end in ".png", a leftover from the savefig filename.
    fig.suptitle("Training and Test data on e={}".format(envs))

    data_train = data_generator(sample_size, envs)
    ax = fig.add_subplot(1, 2, 1)
    _scatter_by_label(ax, data_train, 'Training data')

    data_test = data_generator(sample_size, -envs)
    bx = fig.add_subplot(1, 2, 2)
    # Limits are now applied to the test axis itself (they used to hit ``ax``).
    _scatter_by_label(bx, data_test, 'Test data')

    # Colours are now identical in both panels, so one legend (on the
    # current, right-hand axis) covers both.
    plt.legend()
    # To save instead: fig.savefig("Training and Test data on e={}.png".format(envs))
    plt.show()
    return [data_train, data_test]


def high_data_loader(env_number, envs_array, sample_size):
    """Draw one ``data_high_generator`` sample per environment and plot each one.

    Each environment gets its own subplot in a two-column grid. Points are
    coloured by binary label (0 blue, 1 red).

    Args:
        env_number: Unused by this function; kept for caller compatibility.
        envs_array: Sequence (e.g. a 1-D numpy array) of environment values.
        sample_size: Number of samples drawn per environment.

    Returns:
        list of the dicts returned by ``data_high_generator``, in
        ``envs_array`` order.
    """
    n_envs = len(envs_array)
    # Keep the original 3x2 layout for up to six environments. Add rows
    # beyond that; a fixed 3x2 grid made add_subplot raise for a 7th panel.
    n_rows = max(3, (n_envs + 1) // 2)

    sample = []
    fig = plt.figure(figsize=(15, 13))
    fig.suptitle("Visualization of high data under env={}".format(envs_array))
    for i, env in enumerate(envs_array):
        data = data_high_generator(sample_size, env)

        ax = fig.add_subplot(n_rows, 2, i + 1)
        ax.set_ylim(-300, 300)
        ax.set_xlim(-100, 100)
        points = data['images'] * 150  # undo the generator's 1/150 scaling
        flat_labels = data['labels'].reshape(-1)
        for label, color in ((0, "b"), (1, "r")):
            chosen = points[flat_labels == label]
            ax.scatter(chosen[:, 0], chosen[:, 1], color=color, label="label:{}".format(label), s=10)

        sample.append(data)
    plt.legend()
    # To save instead: fig.savefig("high_data env={}.png".format(envs_array))
    plt.show()
    return sample




    